###########Code to initialize model primitives and estimands and compute value functions and other useful things

# Primitives: fixed (non-estimated) model parameters, state/choice grids, and grid sizes.
# Built with @with_kw, so every field can be overridden by keyword; Initialize passes
# trans_probs, div_chars and μ_grid. The size fields (n_*) default to the length of the
# matching grid, so they follow any grid that is passed in by keyword.
@with_kw struct Primitives
    β::Float64 = 0.95 #discount factor
    σ_ζℓ::Float64 = 1.0 #scale of location utility shocks (normalized)

    #spouse wage process; presumably entry r of each vector belongs to race r -- confirm
    λ_0s_vec::Array{Float64, 1} = [2.234, 2.243, 2.161]  #first term
    λ_1s_vec::Array{Float64, 1} = [0.571, 0.558, 0.512] #second term
    λ_2s_vec::Array{Float64, 1} = [0.047, 0.049, 0.0374] #third term
    λ_3s_vec::Array{Float64, 1} = [-0.0007, -0.0008, -0.0005] #fourth term

    #wage fixed effects for agent and spouse (μ_l/μ_h are not used by μ_grid's default below)
    μ_l::Float64 = -0.413 #low fixed effect
    μ_h::Float64 = 0.414 #high fixed effect
    μ_ls::Float64 = -0.389 #low fixed effect, spouse
    μ_hs::Float64 = 0.391 #high fixed effect, spouse

    #grid spaces
    μ_grid::Array{Float64,1} = [0.0, 0.0, 0.0] #placeholder; Initialize replaces it with a 2-point grid
    μs_grid::Array{Float64,1} = [0, μ_ls, μ_hs] #spouse effects: 0, low, high; appears to line up with m_grid -- confirm
    e_grid::Array{Int64,1} = [0,1] #college attainment, 0 or 1.
    m_grid::Array{Int64,1} = collect(0:1:2) #marital status, 0 for unmarried, 1 for low-quality spouse, 2 for high-quality spouse
    p_grid::Array{Int64,1} = [0,1] #previous LFP status, 0 or 1.
    ℓ_grid::Array{Int64,1} = collect(1:1:10) #10 location options. NOTE(review): old comment said 1 = parent, 2-5 = other (ll, hl, lh, hh wage/ccc), which does not match 10 entries -- confirm
    ac_grid::Array{Int64,1} = collect(-1:1:4) #age of youngest child. -1: no child present
    f_grid::Array{Int64,1} = [0,1] #pregnancy status, 0 or 1.
    ℓp_type_grid::Array{Float64,1} = collect(1:1:9) #9 parent-location types, stored as Float64. Old note: indexed by ll, hl, lh, hh wage/ccc -- confirm
    h_grid::Array{Int64,1} = [0, 1] #hours decisions
    τ_grid::Array{Int64, 1} = [0, 1] #time-transfer type; presumably 0/1 = low/high type -- confirm
    e_s_grid::Array{Int64, 1} = [0, 1] #spousal college attainment, 0 or 1

    #grid dimensions, each defaulting to the length of its grid. dimensions that vary based on life-cycle period go last
    n_μ::Int64 = length(μ_grid) #number of fixed effects
    n_μs::Int64 = length(μs_grid) #number of fixed effects, spouse
    n_e::Int64 = length(e_grid) #number of education states
    n_m::Int64 = length(m_grid) #number of marital states
    n_p::Int64 = length(p_grid) #number of previous LFP states
    n_ℓ::Int64 = length(ℓ_grid) #number of location options
    n_ac::Int64 = length(ac_grid) #number of child age states
    n_f::Int64 = length(f_grid) #number of fertility states
    n_ℓp::Int64 = length(ℓp_type_grid) #number of possible parent location types
    n_h::Int64 = length(h_grid) #number of hours choices
    n_τ::Int64 = length(τ_grid) #number of time-transfer types
    n_e_s::Int64 = length(e_s_grid) #number of spousal education states

    #####age-experience grids and dimensions for the three parts of the life-cycle.
    #a_min_k/a_max_k are stored separately from a_grid_k and must match its endpoints:
    #Backward_Induct derives the ages it visits from a_max_k and n_a_k.
    #stage 1 -- anything can happen
    a_grid_1::Array{Int64,1} = collect(22:1:40) #age grid
    x_grid_1::Array{Int64,1} = collect(0:1:22) #experience grid
    n_a_1::Int64 = length(a_grid_1) #dimensions
    n_x_1::Int64 = length(x_grid_1)
    a_min_1::Int64 = 22
    a_max_1::Int64 = 40 #mins, maxes

    #stage 2 -- no more fertility or moving, but can still have young kids
    a_grid_2::Array{Int64,1} = collect(41:1:44) #age grid
    x_grid_2::Array{Int64,1} = collect(0:1:26) #experience grid
    n_a_2::Int64 = length(a_grid_2) #dimensions
    n_x_2::Int64 = length(x_grid_2)
    a_min_2::Int64 = 41
    a_max_2::Int64 = 44 #mins, maxes

    #stage 3 -- no more young kids
    a_grid_3::Array{Int64,1} = collect(45:1:65) #age grid
    x_grid_3::Array{Int64,1} = collect(0:1:40) #experience grid
    n_a_3::Int64 = length(a_grid_3) #dimensions
    n_x_3::Int64 = length(x_grid_3)
    a_min_3::Int64 = 45
    a_max_3::Int64 = 65 #mins, maxes

    #division characteristics (9 rows, presumably one per division -- confirm); zero placeholder.
    #Initialize appends the women's η terms as an extra column.
    div_chars::Array{Float64, 2} = zeros(9, 9)

    #transition probabilities; zero placeholder, Initialize supplies the real matrix
    trans_probs::Array{Float64,2} = zeros(7,7) #preallocation. Will update later
end

# Estimands: the estimated parameters. Each one is read from a fixed position of the
# flat parameter vector `guess`. Built with @with_kw (mutable), so any parameter can be
# overridden by keyword or mutated after construction.
# The highest position read below is guess[49], so `guess` needs at least 49 entries.
# The default is zeros(49), so a bare Estimands() constructs instead of throwing a
# BoundsError. (Initialize additionally reads guess[38:45] and guess[47] directly.)
@with_kw mutable struct Estimands
    guess::Vector{Float64} = zeros(49) #flat parameter vector (length >= 49)

    #utility parameters
    α_1::Float64 = guess[1] #consumption
    α_2::Float64 = guess[2] #leisure
    α_3::Float64 = guess[3] #switching cost
    α_4::Float64 = guess[4] #parent preference
    α_5::Float64 = guess[5] #consumption, with kid
    α_6::Float64 = guess[6] #leisure, with kid
    α_7::Float64 = guess[3] #switching cost, with kid. NOTE(review): reads guess[3] like α_3, tying the two -- confirm intended
    α_8::Float64 = guess[7] #parent preference, with kid
    α_9::Float64 = guess[8] #leisure/consumption complementarity
    α_10::Float64 = guess[9] #college leisure modifier
    α_μ::Float64 = guess[48] #leisure preference for low type
    α_x::Float64 = guess[24] #experience leisure effect
    μ_prob::Float64 = guess[49] #probability of high type

    #time transfer parameters
    τ_s::Float64 = guess[10] #time transfer, spouse
    τ_p::Float64 = guess[11] #time transfer, parent
    τ_p_2::Float64 = guess[12] #time transfer, parent (second parameter)

    #woman wage parameters
    λ_0::Float64 = guess[13] #constant
    λ_1::Float64 = guess[14] #education
    λ_2::Float64 = guess[15] #experience, linear, college
    λ_3::Float64 = guess[16] #experience, quad, college
    σ_ε::Float64 = guess[17] #wage shock
    σ_ξ::Float64 = guess[18] #measurement error

    #moving cost parameters
    γ_0::Float64 = guess[19] #fixed cost
    γ_1::Float64 = guess[20] #education
    γ_2::Float64 = guess[37] #age
    γ_3::Float64 = guess[21] #kids present
    γ_4::Float64 = guess[22] #married
    γ_5::Float64 = guess[23] #population

    #type probabilities
    τ_prob::Float64 = guess[25] #probability of low transfer type

    #more wage parameters
    λ_4::Float64 = guess[26] #dummy for kid age 0-1
    λ_5::Float64 = guess[27] #dummy for kid age 2-4
    λ_6::Float64 = guess[46] #college-eta interaction

    #amenities
    α_amen_1::Float64 = guess[28]
    α_amen_2::Float64 = guess[29]
    α_amen_3::Float64 = guess[30]

    #fertility preferences
    θ_1::Float64 = guess[31] #fixed cost
    θ_2::Float64 = guess[32] #marriage effect
    θ_3::Float64 = guess[33] #age effect
    θ_4::Float64 = guess[34] #age-marriage interaction
    θ_5::Float64 = guess[36] #NOTE(review): originally also labelled "age-marriage interaction" like θ_4 -- confirm meaning
    σ_θ::Float64 = guess[35] #spread of shocks
end

"""
    Results

Preallocated `SharedArray`s for the value-function objects of each life-cycle stage.
The field order is the positional-constructor order used by `Initialize`. The dimension
order, as allocated in `Initialize`, is:

- `emax_1`, `cutoff_1_fert`: (μ, e, m, p, ℓ, age_1, experience_1, child age, ℓp type, τ, spouse educ)
- `cutoff_1`: as `emax_1`, with a fertility dimension inserted after child age
- `emax_2`, `cutoff_2`: as `emax_1`, but with the stage-2 age/experience grids (no fertility dimension)
- `emax_3`, `cutoff_3`: (μ, e, m, p, ℓ, age_3, experience_3, ℓp type, spouse educ)
"""
mutable struct Results
    #phase 1
    emax_1::SharedArray{Float64, 11}
    cutoff_1::SharedArray{Float64, 12}

    #phase 2
    emax_2::SharedArray{Float64, 11}
    cutoff_2::SharedArray{Float64, 11}

    #phase 3 (no child-age or τ dimensions)
    emax_3::SharedArray{Float64, 9}
    cutoff_3::SharedArray{Float64, 9}

    #fertility shock cutoffs
    cutoff_1_fert::SharedArray{Float64, 11}
end

"""
    Utility(c)

Per-period utility of consumption `c`: the natural logarithm `log(c)`.
Non-positive `c` raises whatever `log` raises (a `DomainError` for negative reals).
"""
Utility(c) = log(c)

"""
    get_index(val, grid) -> Float64

Return the fractional (linearly interpolated) position of `val` within `grid`.

For `grid[k] <= val < grid[k+1]` the result is
`k + (val - grid[k]) / (grid[k+1] - grid[k])`, so a value lying exactly on `grid[k]`
gives `k`. Values at or below `grid[1]` give `1.0`. Values at or above `grid[end]` give
`length(grid)`. The result is always a `Float64`, so callers see a single type.

`grid` is assumed to be sorted in increasing order.
Throws `ArgumentError` if `grid` is empty.
"""
function get_index(val::Float64, grid::Array{Float64,1})
    isempty(grid) && throw(ArgumentError("get_index: grid must be non-empty"))
    n = length(grid)

    # clamp to the endpoints; returned as Float64 to keep the function type-stable
    val <= grid[1] && return 1.0
    val >= grid[n] && return Float64(n)

    # here grid[1] < val < grid[n], so a strictly larger point exists at index >= 2
    upper = findfirst(>(val), grid)
    lower = upper - 1
    g_lo, g_hi = grid[lower], grid[upper]
    return lower + (val - g_lo) / (g_hi - g_lo)
end

"""
    Initialize(guess, trans_probs, div_chars) -> (prim, est, res)

Build the model primitives, the estimands and the preallocated shared result arrays
for the parameter vector `guess`.

- `μ_grid` is set to the two-point grid `[0.0, guess[47]]`.
- `guess[38:45]` are the women's η terms. They are appended as a new last column of
  `div_chars`, with `0.0` placed in the fourth row (presumably a normalized reference
  division -- confirm). `hcat` builds a new matrix, so the caller's `div_chars` is untouched.
- `trans_probs` is passed straight through to `Primitives`.
- `est` is `Estimands(guess = guess)`.
- `res` holds uninitialized `SharedArray{Float64}`s sized from `prim`'s grids.
"""
function Initialize(guess::Array{Float64,1}, trans_probs::Array{Float64,2}, div_chars::Array{Float64, 2})
    # two-point μ grid: a zero point and a point at guess[47]
    μ_grid = [0.0, guess[47]]

    # women's η terms with the fourth entry pinned at zero
    η_women = vcat(guess[38:40], 0.0, guess[41:45])

    prim = Primitives(trans_probs = trans_probs,
                      div_chars = hcat(div_chars, η_women),
                      μ_grid = μ_grid)
    est = Estimands(guess = guess)

    # leading dimensions shared by every stage, plus per-stage pieces
    core = (prim.n_μ, prim.n_e, prim.n_m, prim.n_p, prim.n_ℓ)
    ages_1 = (prim.n_a_1, prim.n_x_1, prim.n_ac)
    ages_2 = (prim.n_a_2, prim.n_x_2, prim.n_ac)
    ages_3 = (prim.n_a_3, prim.n_x_3)
    tail = (prim.n_ℓp, prim.n_τ, prim.n_e_s)
    shared(dims) = SharedArray{Float64}(dims)

    res = Results(
        shared((core..., ages_1..., tail...)),                  # emax_1
        shared((core..., ages_1..., prim.n_f, tail...)),        # cutoff_1 (adds fertility)
        shared((core..., ages_2..., tail...)),                  # emax_2
        shared((core..., ages_2..., tail...)),                  # cutoff_2
        shared((core..., ages_3..., prim.n_ℓp, prim.n_e_s)),    # emax_3 (no kids, no τ)
        shared((core..., ages_3..., prim.n_ℓp, prim.n_e_s)),    # cutoff_3
        shared((core..., ages_1..., tail...)),                  # cutoff_1_fert
    )
    return prim, est, res
end

"""
    Backward_Induct(prim, est, res; cfact = 0, race = 1)

Run backward induction over the three life-cycle stages.

The work is split across distributed workers by parent-location type (`1:prim.n_ℓp`).
For each type, ages are visited from oldest to youngest: stage 3 (`Compute_Valfunc_3`),
then stage 2 (`Compute_Valfunc_2`), then stage 1 (`Compute_Valfunc_1`). Each stage covers
`n_a_k` ages ending at `a_max_k`. These calls are expected to write their results into
the shared arrays in `res`. `cfact` and `race` are forwarded unchanged. The call blocks
until every worker has finished (`@sync`).
"""
function Backward_Induct(prim::Primitives, est::Estimands, res::Results; cfact::Int64 = 0, race::Int64 = 1)
    n_types = prim.n_ℓp

    # descending age ranges, oldest first within each stage
    stage_3_ages = prim.a_max_3:-1:(prim.a_max_3 - prim.n_a_3 + 1)
    stage_2_ages = prim.a_max_2:-1:(prim.a_max_2 - prim.n_a_2 + 1)
    stage_1_ages = prim.a_max_1:-1:(prim.a_max_1 - prim.n_a_1 + 1)

    @sync @distributed for i_ℓp = 1:n_types
        # stage 3: no young kids
        for age in stage_3_ages
            Compute_Valfunc_3(prim, est, res, age, i_ℓp; cfact = cfact, race = race)
        end

        # stage 2: no fertility or moving, young kids still possible
        for age in stage_2_ages
            Compute_Valfunc_2(prim, est, res, age, i_ℓp; cfact = cfact, race = race)
        end

        # stage 1: all transitions possible
        for age in stage_1_ages
            Compute_Valfunc_1(prim, est, res, age, i_ℓp; cfact = cfact, race = race)
        end
    end
end

"""
    Compute_Prob_Marr(mp, e_s_p, trans_states, prim) -> Float64

Return the probability of moving to marital state `mp` with spousal education `e_s_p`,
given the current state `trans_states = [e, m, age, e_s]` (own education, marital
status, age, spousal education).

It reads `prim.trans_probs` as follows:
- Row `1 + e` (row 1 without college, row 2 with college):
  - column 1 weights spousal education (`e_s_p == 1` gets `row[1]`, `e_s_p == 0` gets `1 - row[1]`);
  - columns 2-3 are the intercept and age slope of the marriage probability;
  - column 4 is the probability of remaining married.
- Entry `[3, 1]` is the share of new spouses who are low quality (`mp == 1`); `mp == 2` gets the rest.

Rules:
- `m + mp == 3` (switching directly between spouse types 1 and 2) has probability 0.
- Unmarried (`m == 0`): marrying (`mp > 0`) and staying single (`mp == 0`) are both
  weighted by the spousal-education probability for `e_s_p`. Weighting the single state
  too means that summing over `e_s_p` does not count it twice.
- Married (`m != 0`): spousal education never changes, so `e_s_p != e_s` has probability 0.
  Otherwise staying married has probability `row[4]` and divorcing (`mp == 0`) has `1 - row[4]`.

The result is clamped to `[0, 1]` and is always a `Float64`.
"""
function Compute_Prob_Marr(mp::Int64, e_s_p::Int64, trans_states::Array{Int64,1}, prim::Primitives)
    trans_probs = prim.trans_probs
    e, m, age, e_s = trans_states[1], trans_states[2], trans_states[3], trans_states[4]

    m + mp == 3 && return 0.0             # transitions between 1-2 and 2-1 never happen
    m != 0 && e_s != e_s_p && return 0.0  # married spouses never switch education

    row = @view trans_probs[1 + e, :]     # row 1 if no college, row 2 if college

    if m == 0
        # weight for the spousal-education draw; codes other than 0/1 are left unweighted
        w_edu = e_s_p == 1 ? row[1] : (e_s_p == 0 ? 1 - row[1] : 1.0)
        prob = row[2] + row[3] * age      # probability of marrying this period
        if mp > 0
            p_low = trans_probs[3, 1]     # probability of a low-quality spouse
            prob = prob * (p_low, 1 - p_low)[mp] * w_edu
        elseif mp == 0
            prob = (1 - prob) * w_edu     # stay single
        end
    else
        prob = mp == 0 ? 1 - row[4] : row[4]  # divorce vs. remain married
    end

    return clamp(prob, 0.0, 1.0)
end

"""
    Compute_Prob_Marr_old(mp, trans_states, prim) -> Float64

Older marital-transition probability, apparently kept for reference; it does not model
spousal education. `trans_states = [e, m, age]`.

It reads `prim.trans_probs` as follows:
- rows `5 + e` (5 without college, 6 with college): intercept and age slope of the marriage probability;
- rows `7 + e` (7 without college, 8 with college): intercept and age slope of the probability of staying married;
- entry `[9, 1]`: share of new spouses who are low quality (`mp == 1`); `mp == 2` gets the rest.

`m + mp == 3` (switching directly between spouse types) has probability 0. The result is
clamped to `[0, 1]` and is always a `Float64`.
"""
function Compute_Prob_Marr_old(mp::Int64, trans_states::Array{Int64,1}, prim::Primitives)
    trans_probs = prim.trans_probs
    e, m, age = trans_states[1], trans_states[2], trans_states[3]

    m + mp == 3 && return 0.0  # transitions between 1-2 and 2-1 never happen

    if m == 0
        row = @view trans_probs[5 + e, :]  # row 5 if no college, row 6 if college
        prob = row[1] + row[2] * age       # probability of marrying
        if mp > 0
            p_low = trans_probs[9, 1]      # probability of a low-quality spouse
            prob *= (p_low, 1 - p_low)[mp]
        elseif mp == 0
            prob = 1 - prob                # stay single
        end
    else
        row = @view trans_probs[7 + e, :]  # row 7 if no college, row 8 if college
        prob = row[1] + row[2] * age       # probability of staying married
        if mp == 0
            prob = 1 - prob                # divorce
        end
    end

    return clamp(prob, 0.0, 1.0)
end
